% This script reproduces Figure 8 of the paper and writes it to figure8.eps.

close all
clear

% Shared state for the helper functions this script calls (PDF, CDF, C_x,
% C_s, x_dot_max, s_dot_max). NOTE(review): those helpers live in other
% files; presumably they read these globals -- confirm.

% Grid size and the state-space coordinate matrices:
global grid_size x s
% Scalar model parameters:
global rho delta alpha gamma_x gamma_s
% Cost-function coefficients for c(z) = c_1*z + c_2*z^alpha:
global c_x_1 c_x_2 c_s_1 c_s_2
% Upper bounds on x_dot and s_dot:
global x_dot_bar s_dot_bar
% Output weights:
global phi_x phi_s

% ---- Grid, scalar parameters and bounds ----
grid_size = 121;
rho = 1000;
delta = .1;
alpha = 2;              % convexity parameter
x_dot_bar = 15-delta;   % upper bound on x_dot
s_dot_bar = 10-delta;   % upper bound on s_dot

% Stopping tolerance for the fixed-point iteration.
% NOTE(review): 10e-12 is 1e-11; confirm 1e-12 was not intended.
error_tol = 10e-12;

% ---- Output weights ----
phi_x = 0;
phi_s = 1;

% ---- Cost functions c(z) = c_1*z + c_2*z^alpha ----
% Linear coefficients:
c_x_1 = 0;
c_s_1 = 0;

% ---- Assumption 3 ----
% gamma is a fraction of PDF(0)-PDF(1), capped at 1 (values are echoed).
pdf_gap = PDF(0)-PDF(1);
gamma_x = min(.4*pdf_gap,1)
gamma_s = min(.35*pdf_gap,1)

% Convex coefficients c_2: a weighted mix of PDF(1) and
% min(PDF(0)-gamma, PDF(gamma)), scaled by 1/(alpha*delta^(alpha-1)).
convex_weight_x = 0;
convex_weight_s = 0;
c2_scale = alpha*delta^(alpha-1);
c_x_2 = (convex_weight_x*PDF(1) + (1-convex_weight_x)*min(PDF(0)-gamma_x,PDF(gamma_x)))/c2_scale
c_s_2 = (convex_weight_s*PDF(1) + (1-convex_weight_s)*min(PDF(0)-gamma_s,PDF(gamma_s)))/c2_scale

% Drop the temporaries so the workspace matches the unfactored version.
clear pdf_gap c2_scale

% State grid on the unit square: x(i,j) varies with the row index i and
% s(i,j) with the column index j (both sweep linspace(0,1,grid_size)).
x = linspace(0,1,grid_size)'*ones(1,grid_size);
s = ones(grid_size,1)*linspace(0,1,grid_size);

% Initial value functions (V_x, V_s) and policy functions (x_dot, s_dot):
% all zero on the grid.
x_dot = zeros(grid_size,grid_size);
s_dot = zeros(grid_size,grid_size);
V_x = zeros(grid_size,grid_size);
V_s = zeros(grid_size,grid_size);

% Step passed to gradient() in both directions.
% NOTE(review): this is half the grid spacing 1/(grid_size-1). gradient()
% already divides interior central differences by 2*step, so the derivative
% estimates come out twice the finite-difference slope. Confirm this scaling
% is intended before changing it (it changes the figure).
grad_step = 1/(grid_size-1)/2;

% Fixed-point iteration: update policies from the current value gradients,
% then update the value functions, until the summed change (matrix
% infinity norms of the four updates) is no larger than error_tol.
% There is no iteration cap; the loop runs until that criterion is met.
% The residual is deliberately not named `error`, which would shadow
% MATLAB's built-in error() for the rest of the script.
iter_err = error_tol*10;

while iter_err > error_tol
    % gradient() returns the derivative along columns (s) first, then along
    % rows (x).
    [dV_x_over_ds,dV_x_over_dx] = gradient(V_x,grad_step,grad_step);
    [dV_s_over_ds,dV_s_over_dx] = gradient(V_s,grad_step,grad_step);

    x_dot_new = x_dot_max(dV_x_over_dx);
    s_dot_new = s_dot_max(dV_s_over_ds);

    V_x_new = CDF(x-s).*(1+phi_x.*x+phi_s.*s)+(1/rho)*(-C_x(x_dot_new)+dV_x_over_dx.*x_dot_new+dV_x_over_ds.*s_dot_new);
    V_s_new = CDF(s-x).*(1+phi_x.*x+phi_s.*s)+(1/rho)*(-C_s(s_dot_new)+dV_s_over_dx.*x_dot_new+dV_s_over_ds.*s_dot_new);

    iter_err = norm(V_x_new-V_x,Inf)+norm(V_s_new-V_s,Inf)+norm(x_dot_new-x_dot,Inf)+norm(s_dot_new-s_dot,Inf);

    V_x = V_x_new;
    V_s = V_s_new;
    x_dot = x_dot_new;
    s_dot = s_dot_new;
end
% Keep the policy field from pointing out of the unit square along its edges
% (x varies down the rows, s across the columns):
%   x_dot >= 0 on the x = 0 row,    x_dot <= 0 on the x = 1 row,
%   s_dot >= 0 on the s = 0 column, s_dot <= 0 on the s = 1 column.
% Each upper-edge clamp must read its own row/column (grid_size). Reading
% row/column 1, which has already been clamped to >= 0, would just zero
% that edge.
x_dot(1,:)         = max(x_dot(1,:), 0);
s_dot(:,1)         = max(s_dot(:,1), 0);
x_dot(grid_size,:) = min(x_dot(grid_size,:), 0);
s_dot(:,grid_size) = min(s_dot(:,grid_size), 0);

figure
axis equal
axis tight
hold on

%streamslice(s(1,:),x(:,1),s_dot,x_dot)

% Hand-fitted curves x = c0 + c1*s + c2*s^2, sampled at 500 points of
% s in [0,1]. Only the part with 0 <= x <= 1 is drawn.
% Rows are drawn in order: two red curves, then two green curves.
curve_coeffs = [ 0.31, 1.65, -0.1  ;
                -0.24, 0.67, -0.045;
                 0.45, 1.65, -0.4  ;
                -0.12, 0.63, -0.21 ];
curve_colors = {'r', 'r', 'g', 'g'};

s_fine = linspace(0,1,500);
for k = 1:size(curve_coeffs,1)
    c = curve_coeffs(k,:);
    x_curve = c(1) + c(2)*s_fine + c(3)*s_fine.^2;
    in_box = (x_curve >= 0) & (x_curve <= 1);
    plot(s_fine(in_box), x_curve(in_box), curve_colors{k}, 'LineWidth', 1.1)
end

% Draw the (s_dot, x_dot) flow separately in three regions of the grid, each
% in its own colour: magenta above the upper green curve, blue below the
% lower green curve, black in between. Outside a region, the velocity field
% is zeroed before it is handed to streamslice.
% NOTE(review): the "in between" region uses 1.57 as the linear coefficient
% of its upper boundary, while the magenta region and the plotted green
% curve use 1.65. This leaves a band with no streamlines. Confirm whether
% that gap is intentional.
above_upper = x > 0.45+1.65*s-.4*s.^2;
below_lower = x < -0.12+.63*s-0.21*s.^2;
in_between  = (x < 0.45+1.57*s-.4*s.^2) & (x > -0.12+.63*s-0.21*s.^2);

region_masks  = {above_upper, below_lower, in_between};
region_colors = {'m', 'b', 'k'};

for r = 1:numel(region_masks)
    mask = region_masks{r};
    x_dot_sub = zeros(grid_size,grid_size);
    s_dot_sub = zeros(grid_size,grid_size);
    x_dot_sub(mask) = x_dot(mask);
    s_dot_sub(mask) = s_dot(mask);
    h = streamslice(s(1,:),x(:,1),s_dot_sub,x_dot_sub);
    set(h, 'Color', region_colors{r})
end

% Fix the view to the unit square plus a small margin, label the axes, and
% write the figure to figure8.eps (colour EPS, painters renderer).
% Orientation: streamslice receives s(1,:) as its X (horizontal) argument and
% x(:,1) as its Y (vertical) argument, and the curves are drawn as
% plot(s, x). So the horizontal axis is s and the vertical axis is x.
axis([-0.02 1.02 -0.02 1.02])
xlabel('s')
ylabel('x')
set(get(gca,'YLabel'),'Rotation',0)
print('figure8','-depsc','-painters')